{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt \n",
    "import os\n",
    "import io\n",
    "from image_function import *\n",
    "from model import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder of the sample to process. The trailing '' makes os.path.join\n",
    "# end the path with a separator, which the string concatenations that build\n",
    "# the sub-folder paths rely on.\n",
    "main_folder = os.path.join('Data', 'sample3', '')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working folders of the sample. Every path keeps a trailing separator\n",
    "# (the final '' argument) so a file name can be appended directly.\n",
    "Path_full_size = os.path.join(main_folder, 'pad', '')    # full-size images to crop\n",
    "Path_crop = os.path.join(main_folder, 'temp', '')        # cropped pieces\n",
    "Path_predict = os.path.join(main_folder, 'Predict', '')  # stitched predictions\n",
    "\n",
    "# Each image is 1024 x 1024; it is cropped into 1024 pieces of 32 x 32,\n",
    "# and every piece is then upsampled to 256 x 256 for prediction.\n",
    "N_pieces = 1024  # crop to this number of pieces before upsampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the working folders if they do not exist yet. Both paths already\n",
    "# end with a separator, so they are passed unchanged.\n",
    "for folder in (Path_crop, Path_predict):\n",
    "    ensure_dir(folder)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict results using the trained U-net model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Step 1: Crop the original images into small pieces (there are 52 samples in total)\n",
    "\n",
    "#### Step 2: Read the cropped images for one sample, then upsample them to the input shape of the network (256 x 256 in this example)\n",
    "\n",
    "#### Step 3: Create prediction results of one sample \n",
    "\n",
    "#### Step 4: Stitch the predictions back into a whole image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def testGenerator(test_path,N_slices, N_image_start, N_image_stop,target_size = (256,256),flag_multi_class = False,as_gray = True):\n",
    "    \"\"\"Yield the cropped images one at a time, prepared as network input.\n",
    "\n",
    "    Arguments:\n",
    "        test_path: path of the images to read\n",
    "        N_slices: number of slices per image\n",
    "        N_image_start: index of the first image to read\n",
    "        N_image_stop: index of the last image to read\n",
    "        target_size: shape every image is resized to\n",
    "        flag_multi_class: when False, a trailing axis of length 1 is appended\n",
    "        as_gray: forwarded to io.imread\n",
    "\n",
    "    The file names are produced by\n",
    "    create_filename(N_slices, N_image_start, N_image_stop). Each yielded\n",
    "    array is divided by 255, resized to target_size and given a leading\n",
    "    axis of length 1.\n",
    "    \"\"\"\n",
    "    for name in create_filename(N_slices, N_image_start, N_image_stop):\n",
    "        piece = io.imread(os.path.join(test_path, name), as_gray = as_gray)\n",
    "        piece = trans.resize(piece / 255, target_size)\n",
    "        if not flag_multi_class:\n",
    "            # single-class input: add the trailing axis\n",
    "            piece = np.expand_dims(piece, axis = -1)\n",
    "        # leading axis of length 1 in front of the image axes\n",
    "        yield np.expand_dims(piece, axis = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def saveResult(save_path,npyfile,flag_multi_class = False,num_class = 2):\n",
    "    \"\"\"Save every entry of npyfile into save_path as '<index>_predict.jpg'.\n",
    "\n",
    "    With flag_multi_class set, an entry is first converted by\n",
    "    labelVisualize(num_class, COLOR_DICT, entry); otherwise only its first\n",
    "    channel, entry[:, :, 0], is written.\n",
    "    \"\"\"\n",
    "    for index, prediction in enumerate(npyfile):\n",
    "        if flag_multi_class:\n",
    "            picture = labelVisualize(num_class, COLOR_DICT, prediction)\n",
    "        else:\n",
    "            picture = prediction[:, :, 0]\n",
    "        target = os.path.join(save_path, \"%d_predict.jpg\" % index)\n",
    "        io.imsave(target, picture)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Crop the full-size images into small patches.\n",
    "# The returned value is used below as the number of images to predict.\n",
    "images_n = read_crop_write(Path_full_size, Path_crop, N_pieces)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tf_gpu\\lib\\site-packages\\tensorflow\\python\\framework\\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From C:\\ProgramData\\Anaconda3\\envs\\tf_gpu\\lib\\site-packages\\keras\\backend\\tensorflow_backend.py:3445: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\yizhe\\Desktop\\Yi\\FP03.006Demo\\model.py:67: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor(\"in..., outputs=Tensor(\"co...)`\n",
      "  model = Model(input = inputs, output = outputs)\n"
     ]
    }
   ],
   "source": [
    "# Build the network with unet(), then load the trained weights into it.\n",
    "weights_file = \"unet.h5\"\n",
    "model = unet()\n",
    "model.load_weights(weights_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\ProgramData\\Anaconda3\\envs\\tf_gpu\\lib\\site-packages\\skimage\\transform\\_warps.py:105: UserWarning: The default mode, 'constant', will be changed to 'reflect' in skimage 0.15.\n",
      "  warn(\"The default mode, 'constant', will be changed to 'reflect' in \"\n",
      "C:\\ProgramData\\Anaconda3\\envs\\tf_gpu\\lib\\site-packages\\skimage\\transform\\_warps.py:110: UserWarning: Anti-aliasing will be enabled by default in skimage 0.15 to avoid aliasing artifacts when down-sampling images.\n",
      "  warn(\"Anti-aliasing will be enabled by default in skimage 0.15 to \"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1024/1024 [==============================] - 8s 8ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n",
      "1024/1024 [==============================] - 7s 6ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - ETA:  - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 7s 7ms/step\n",
      "1024/1024 [==============================] - 6s 6ms/step\n"
     ]
    }
   ],
   "source": [
    "# One pass per image: predict its N_pieces patches, stitch the patch\n",
    "# predictions together and save the stitched result.\n",
    "for image_index in range(images_n):\n",
    "    # first and last image index are the same: one image per pass\n",
    "    patch_stream = testGenerator(Path_crop, N_pieces, image_index, image_index)\n",
    "    patch_predictions = np.squeeze(\n",
    "        model.predict_generator(patch_stream, N_pieces, verbose = 1)\n",
    "    )\n",
    "\n",
    "    stitched = stitch_image(patch_predictions, N_pieces)\n",
    "    whole_img = trans.resize(stitched, [1024, 1024])  # resize to the 1024 x 1024 output size\n",
    "\n",
    "    # output file is named after the image index\n",
    "    out_file = Path_predict + str(image_index) + '.jpg'\n",
    "    cv2.imwrite(out_file, whole_img*100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
